﻿/*
 * Licensed to Apereo under one or more contributor license
 * agreements. See the NOTICE file distributed with this work
 * for additional information regarding copyright ownership.
 * Apereo licenses this file to you under the Apache License,
 * Version 2.0 (the "License"); you may not use this file
 * except in compliance with the License. You may obtain a
 * copy of the License at:
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#if NET40 || NET45
using System;
using System.Collections.Generic;
using DotNetCasClient.Logging;
using DotNetCasClient.Utils;

namespace DotNetCasClient.State
{
    /// <summary>
    /// An IServiceTicketManager implementation that relies on the System.Runtime.Caching caching model for ticket storage.  Generally this implies that the ticket storage is maintained locally on the web server (either in memory or on disk).  A limitation of this model is that it will not support clustered, load balanced, or round-robin style configurations.
    /// </summary>
    /// <author>Jason Kanaris</author>
    public sealed class MemoryCacheServiceTicketManager : IServiceTicketManager
    {
        /// <summary>
        /// This prefix is prepended to CAS Service Ticket as the key to the cache.
        /// </summary>
        private const string CACHE_TICKET_KEY_PREFIX = "CasTicket::";

        private static readonly Logger securityLogger = new Logger(Category.Security);

        /// <summary>
        /// Parameterless constructor needed for Reflection to instantiate it properly
        /// </summary>
        public MemoryCacheServiceTicketManager() { }

        /// <summary>
        /// Performs initialization of the MemoryCacheServiceTicketManager
        /// </summary>
        public void Initialize()
        {
            // Do nothing
        }

        /// <summary>
        /// Removes expired entries from the ticket store
        /// </summary>
        public void RemoveExpiredTickets()
        {
            // No-op.  The System.Runtime.Caching.ObjectCache provider removes expired entries automatically.
        }

        /// <summary>
        /// Retrieve a CasAuthenticationTicket from the ticket store by its CAS Service Ticket
        /// </summary>
        /// <param name="serviceTicket">The service ticket generated by the CAS server</param>
        /// <returns>The CasAuthenticationTicket or null if no matching ticket is found</returns>
        /// <exception cref="ArgumentNullException">serviceTicket is null</exception>
        /// <exception cref="ArgumentException">serviceTicket is empty</exception>
        public CasAuthenticationTicket GetTicket(string serviceTicket)
        {
            CommonUtils.AssertNotNullOrEmpty(serviceTicket, "serviceTicket parameter cannot be null or empty.");

            // One lookup is enough: an absent entry (or one that is not a ticket) yields null via 'as'.
            return MemoryCacheManager.Instance.Get(GetTicketKey(serviceTicket)) as CasAuthenticationTicket;
        }

        /// <summary>
        /// Inserts a CasAuthenticationTicket to the ticket store with a corresponding 
        /// ticket expiration date.
        /// </summary>
        /// <param name="casAuthenticationTicket">The CasAuthenticationTicket to insert</param>
        /// <param name="expiration">The date and time at which the ticket expires</param>
        /// <exception cref="ArgumentNullException">casAuthenticationTicket is null</exception>
        public void InsertTicket(CasAuthenticationTicket casAuthenticationTicket, DateTime expiration)
        {
            CommonUtils.AssertNotNull(casAuthenticationTicket, "casAuthenticationTicket parameter cannot be null.");

            // Don't enforce sliding expiration on the cache entry.  Sliding expiration is handled by the HttpModule.
            MemoryCacheManager.Instance.Set(GetTicketKey(casAuthenticationTicket.ServiceTicket), casAuthenticationTicket, expiration);
        }

        /// <summary>
        /// Updates the expiration date and time for an existing ticket.  If the ticket does not exist in the ticket store, just return (do not throw an exception).
        /// </summary>
        /// <param name="casAuthenticationTicket">The CasAuthenticationTicket to insert</param>
        /// <param name="newExpiration">The new expiration date and time</param>
        /// <exception cref="ArgumentNullException">casAuthenticationTicket is null</exception>
        public void UpdateTicketExpiration(CasAuthenticationTicket casAuthenticationTicket, DateTime newExpiration)
        {
            CommonUtils.AssertNotNull(casAuthenticationTicket, "casAuthenticationTicket parameter cannot be null.");

            // Only refresh tickets that are still stored.  Re-inserting unconditionally would
            // resurrect a ticket that expired or was revoked (e.g. via single sign out) after
            // the caller read it, contradicting the contract documented above.
            if (!ContainsTicket(casAuthenticationTicket.ServiceTicket))
            {
                return;
            }

            InsertTicket(casAuthenticationTicket, newExpiration);
        }

        /// <summary>
        /// Removes the ticket from the collection if it exists.  If the ticket does not exist in the ticket store, just return (do not throw an exception).
        /// </summary>
        /// <param name="serviceTicket">The ticket to remove from the ticket store</param>
        /// <exception cref="ArgumentNullException">serviceTicket is null</exception>
        /// <exception cref="ArgumentException">serviceTicket is empty</exception>
        public void RevokeTicket(string serviceTicket)
        {
            CommonUtils.AssertNotNullOrEmpty(serviceTicket, "serviceTicket parameter cannot be null or empty.");

            string key = GetTicketKey(serviceTicket);

            // Only entries that actually hold a CasAuthenticationTicket are removed.
            if (MemoryCacheManager.Instance.Get(key) is CasAuthenticationTicket)
            {
                MemoryCacheManager.Instance.Remove(key);
            }
        }

        /// <summary>
        /// Indicates whether or not the ticket store contains the supplied serviceTicket
        /// </summary>
        /// <param name="serviceTicket">The service ticket to check for</param>
        /// <returns>True if the ticket is contained in the store</returns>
        /// <exception cref="ArgumentNullException">serviceTicket is null</exception>
        /// <exception cref="ArgumentException">serviceTicket is empty</exception>
        public bool ContainsTicket(string serviceTicket)
        {
            CommonUtils.AssertNotNullOrEmpty(serviceTicket, "serviceTicket parameter cannot be null or empty.");

            CasAuthenticationTicket cachedTicket = GetTicket(serviceTicket);
            return cachedTicket != null && cachedTicket.ServiceTicket == serviceTicket;
        }

        /// <summary>
        /// Revoke all tickets corresponding to the supplied NetId.
        /// </summary>
        /// <param name="netId">The NetId to revoke tickets for</param>
        /// <exception cref="ArgumentNullException">The netId supplied is null</exception>
        /// <exception cref="ArgumentException">The netId supplied is empty</exception>
        public void RevokeUserTickets(string netId)
        {
            CommonUtils.AssertNotNullOrEmpty(netId, "netId parameter cannot be null or empty.");

            // Collect the matches first: GetAllTickets() enumerates the cache lazily, so removing
            // entries while that enumeration is still running could invalidate it.
            List<string> ticketsToRevoke = new List<string>();
            foreach (CasAuthenticationTicket ticket in GetAllTickets())
            {
                if (IdentifiersMatch(ticket.NetId, netId))
                {
                    ticketsToRevoke.Add(ticket.ServiceTicket);
                }
            }

            foreach (string serviceTicket in ticketsToRevoke)
            {
                RevokeTicket(serviceTicket);
            }
        }

        /// <summary>
        /// Retrieves all tickets in the ticket store that have not already expired.
        /// </summary>
        /// <returns>An enumerable collection of CasAuthenticationTickets</returns>
        public IEnumerable<CasAuthenticationTicket> GetAllTickets()
        {
            var cacheItems = MemoryCacheManager.Instance.GetAll();
            foreach (var cacheItem in cacheItems)
            {
                // Cache keys are machine identifiers, so the prefix test is ordinal.
                if (cacheItem.Key != null && cacheItem.Key.StartsWith(CACHE_TICKET_KEY_PREFIX, StringComparison.Ordinal))
                {
                    CasAuthenticationTicket currentTicket = cacheItem.Value as CasAuthenticationTicket;
                    if (currentTicket != null)
                    {
                        yield return currentTicket;
                    }
                }
            }
        }

        /// <summary>
        /// Retrieves all non-expired tickets in the ticket store associated with the 
        /// netId supplied.
        /// </summary>
        /// <param name="netId">The NetId to search the collection for</param>
        /// <returns>An enumerable collection of CasAuthenticationTickets</returns>
        /// <exception cref="ArgumentNullException">netId is null</exception>
        /// <exception cref="ArgumentException">netId is empty</exception>
        public IEnumerable<CasAuthenticationTicket> GetUserTickets(string netId)
        {
            // Validate eagerly; inside an iterator this would only run on the first MoveNext().
            CommonUtils.AssertNotNullOrEmpty(netId, "netId parameter cannot be null or empty.");

            return FilterByNetId(GetAllTickets(), netId);
        }

        /// <summary>
        /// Retrieves all CAS Service Tickets in the ticket store that have not already
        /// expired.
        /// </summary>
        /// <returns>An enumerable collection of service tickets</returns>
        public IEnumerable<string> GetAllServiceTickets()
        {
            return SelectServiceTickets(GetAllTickets());
        }

        /// <summary>
        /// Retrieves all non-expired CAS Service Tickets in the ticket store associated 
        /// with the netId supplied.
        /// </summary>
        /// <param name="netId">The netId to search the collection for</param>
        /// <returns>An enumerable collection of service tickets</returns>
        /// <exception cref="ArgumentNullException">netId is null</exception>
        /// <exception cref="ArgumentException">netId is empty</exception>
        public IEnumerable<string> GetUserServiceTickets(string netId)
        {
            // Validate eagerly; inside an iterator this would only run on the first MoveNext().
            CommonUtils.AssertNotNullOrEmpty(netId, "netId parameter cannot be null or empty.");

            return SelectServiceTickets(FilterByNetId(GetAllTickets(), netId));
        }

        /// <summary>
        /// Retrieves a list of all users that have non-expired CAS authentication 
        /// tickets.
        /// </summary>
        /// <returns>An enumerable collection of NetId's, in first-seen order without duplicates</returns>
        public IEnumerable<string> GetAllTicketedUsers()
        {
            List<string> result = new List<string>();

            // HashSet gives O(1) duplicate detection with the same default (ordinal,
            // case-sensitive) equality that List.Contains used; the List keeps first-seen order.
            HashSet<string> seenNetIds = new HashSet<string>();
            foreach (CasAuthenticationTicket ticket in GetAllTickets())
            {
                if (seenNetIds.Add(ticket.NetId))
                {
                    result.Add(ticket.NetId);
                }
            }
            return result.ToArray();
        }

        /// <summary>
        /// Verify that the supplied casAuthenticationTicket exists in the ticket store
        /// </summary>
        /// <param name="casAuthenticationTicket">The casAuthenticationTicket to verify</param>
        /// <returns>
        /// True if the ticket exists in the ticket store and the properties of that 
        /// ticket match the properties of the ticket in the ticket store.
        /// </returns>
        public bool VerifyClientTicket(CasAuthenticationTicket casAuthenticationTicket)
        {
            CommonUtils.AssertNotNull(casAuthenticationTicket, "casAuthenticationTicket parameter cannot be null.");

            string incomingServiceTicket = casAuthenticationTicket.ServiceTicket;
            CasAuthenticationTicket cacheAuthTicket = GetTicket(incomingServiceTicket);
            if (cacheAuthTicket == null)
            {
                securityLogger.Info("Ticket {0} not found in cache.  Never existed, expired, or removed via single sign out",
                    incomingServiceTicket);
                return false;
            }

            if (cacheAuthTicket.ServiceTicket != incomingServiceTicket)
            {
                return false;
            }

            if (!IdentifiersMatch(cacheAuthTicket.NetId, casAuthenticationTicket.NetId))
            {
                securityLogger.Info("Username {0} in ticket {1} does not match cached value.",
                    casAuthenticationTicket.NetId, incomingServiceTicket);
                return false;
            }

            if (!IdentifiersMatch(cacheAuthTicket.Assertion.PrincipalName, casAuthenticationTicket.Assertion.PrincipalName))
            {
                // Arguments follow the format string: {0} = principal name, {1} = service ticket.
                securityLogger.Info("Principal name {0} in assertion of ticket {1} does not match cached value.",
                    casAuthenticationTicket.Assertion.PrincipalName, incomingServiceTicket);
                return false;
            }

            return true;
        }

        /// <summary>
        /// Converts a CAS Service Ticket to its corresponding key in the
        /// ticket manager store (cache provider). 
        /// </summary>
        /// <param name="serviceTicket">
        /// The CAS Service ticket to convert.
        /// </param>
        /// <returns>
        /// The cache key associated with the corresponding 
        /// service ticket
        /// </returns>
        /// <exception cref="ArgumentNullException">serviceTicket is null</exception>
        /// <exception cref="ArgumentException">serviceTicket is empty</exception>
        private static string GetTicketKey(string serviceTicket)
        {
            CommonUtils.AssertNotNullOrEmpty(serviceTicket, "serviceTicket parameter cannot be null or empty.");

            return CACHE_TICKET_KEY_PREFIX + serviceTicket;
        }

        /// <summary>
        /// Case-insensitive, culture-independent comparison used for NetIds and principal names.
        /// </summary>
        /// <param name="left">First identifier (may be null)</param>
        /// <param name="right">Second identifier (may be null)</param>
        /// <returns>True if both are null, or both are equal ignoring case (ordinal rules)</returns>
        private static bool IdentifiersMatch(string left, string right)
        {
            return String.Equals(left, right, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Lazily yields the tickets whose NetId matches <paramref name="netId"/>.
        /// </summary>
        /// <param name="tickets">The tickets to filter</param>
        /// <param name="netId">The NetId to match (case-insensitive)</param>
        /// <returns>The matching tickets, in source order</returns>
        private static IEnumerable<CasAuthenticationTicket> FilterByNetId(IEnumerable<CasAuthenticationTicket> tickets, string netId)
        {
            foreach (CasAuthenticationTicket ticket in tickets)
            {
                if (IdentifiersMatch(ticket.NetId, netId))
                {
                    yield return ticket;
                }
            }
        }

        /// <summary>
        /// Lazily projects each ticket to its CAS Service Ticket string.
        /// </summary>
        /// <param name="tickets">The tickets to project</param>
        /// <returns>The service tickets, in source order</returns>
        private static IEnumerable<string> SelectServiceTickets(IEnumerable<CasAuthenticationTicket> tickets)
        {
            foreach (CasAuthenticationTicket ticket in tickets)
            {
                yield return ticket.ServiceTicket;
            }
        }
    }
}
#endif